# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
# SPDX-License-Identifier: MIT-0
import sys
import numpy as np
from numpy import random as rand
import pandas as pd
from multiprocessing import Pool
import functools, multiprocessing
from datetime import datetime
import math
from collections import OrderedDict
#import pymc as pm
import collections
import time
from scipy.stats import multivariate_normal
from scipy.stats import norm


def get_XYdesign(outers,X,X_all,iterations, threshold, warm_start, sigmas=None):
    """Compute an XY-allocation design with the Frank-Wolfe algorithm.

    Works like ``FrankWolfe_XY`` but, instead of receiving the matrix of
    directions Y as an argument, it rebuilds it every iteration. The chosen
    direction is the pairwise difference ``X[i] - X[j]`` with the largest
    predictive variance under the current design.

    Parameters
    ----------
    outers : ndarray
        Stack of matrices, one per entry of ``warm_start``. They are summed
        with weights ``lambda / sigma`` to form the design matrix.
        Presumably ``outers[i] == np.outer(X_all[i], X_all[i])``.
        TODO confirm with caller.
    X : ndarray
        2-D array whose rows are the arms whose pairwise differences are
        scored.
    X_all : ndarray
        2-D array whose rows are the arms the allocation is defined over.
    iterations : int
        The loop runs for ``k = 1 .. iterations - 1``, so this must be >= 2.
    threshold : float
        Stop once the relative change of the max predictive variance
        between consecutive iterations is below this value.
    warm_start : array_like
        Initial allocation. It is copied and never modified.
    sigmas : array_like, optional
        Per-arm noise scales. Defaults to ones.

    Returns
    -------
    cov_A : ndarray
        Inverse (or pseudo-inverse, if singular) of the last evaluated
        design matrix.
    y_max_val : float
        Largest pairwise predictive variance under that design.
    lambda_vec : ndarray
        Allocation after the final Frank-Wolfe step. That step is applied
        after ``cov_A``/``y_max_val`` are computed, so those two describe
        the allocation *before* it.

    Raises
    ------
    ValueError
        If ``iterations < 2``. In that case no iteration would run and
        there would be nothing to return.
    """
    if iterations < 2:
        raise ValueError("iterations must be at least 2")

    # Float copy: accepts lists and never mutates the caller's array.
    lambda_vec = np.array(warm_start, dtype=float)
    if sigmas is None:
        sigmas = np.ones(len(lambda_vec))
    sigmas = np.asarray(sigmas, dtype=float)

    # X_all.T @ sqrt(inv(diag(sigmas))) is a per-column scaling; it does not
    # depend on the loop, so compute it once via broadcasting.
    scaled_X_all_T = X_all.T * np.sqrt(1.0 / sigmas)

    old_y_max_val = 1  # initial reference for the relative-change criterion
    for k in range(1, iterations):
        # Design matrix induced by the current allocation: sum_i (lambda_i / sigma_i) * outers[i].
        A_lambda = np.tensordot(lambda_vec / sigmas, outers, axes=1)

        # Fall back to the pseudo-inverse when the design is exactly singular.
        if np.linalg.det(A_lambda) == 0:
            cov_A = np.linalg.pinv(A_lambda)
        else:
            cov_A = np.linalg.inv(A_lambda)

        # pred_vars[i, j] = (X[i] - X[j]) @ cov_A @ (X[i] - X[j])
        #                 = X_A[i, i] + X_A[j, j] - 2 * X_A[i, j]
        X_A = X @ cov_A @ X.T
        x_a_diag = np.diag(X_A)
        pred_vars = np.add.outer(x_a_diag, x_a_diag) - 2 * X_A
        i, j = np.unravel_index(np.argmax(pred_vars), pred_vars.shape)
        y_max_val = pred_vars[i, j]
        max_y = X[i] - X[j]  # most uncertain difference between candidate arms

        lambda_derivative = -(max_y @ cov_A @ scaled_X_all_T) ** 2

        # Frank-Wolfe step towards the vertex with the steepest descent.
        alpha = 2 / (k + 2)
        lambda_vec -= alpha * lambda_vec
        lambda_vec[np.argmin(lambda_derivative)] += alpha

        if y_max_val == 0 or abs((old_y_max_val - y_max_val) / old_y_max_val) < threshold:
            break
        old_y_max_val = y_max_val
    return cov_A, y_max_val, lambda_vec



def get_oracle(outers,X,Z,iterations, true_best_index, theta,threshold, warm_start, sigmas=None):
    """Compute the oracle allocation for a known parameter ``theta``.

    Builds one direction per arm other than ``Z[true_best_index]``. Each
    direction is ``Z[true_best_index] - Z[j]`` divided by its gap
    ``(Z[true_best_index] - Z[j]) @ theta``. The resulting matrix is handed
    to ``FrankWolfe_XY`` together with the remaining arguments, and that
    function's result is returned unchanged.

    NOTE(review): a zero gap (an arm tied with the best one under ``theta``)
    produces inf/nan rows in the direction matrix.
    """
    best_arm = Z[true_best_index]
    differences = best_arm - np.delete(Z, true_best_index, axis=0)
    gaps = (differences @ theta).reshape(-1, 1)
    directions = differences / gaps
    return FrankWolfe_XY(outers, X, directions, iterations, threshold, warm_start, sigmas)

    
def FrankWolfe_XY(outers, X, Y, iterations, threshold, warm_start, sigmas=None):
    """Frank-Wolfe optimization of an allocation for a given direction matrix Y.

    Each iteration picks the row of ``Y`` with the largest predictive
    variance ``y @ cov_A @ y`` under the current design. It then moves the
    allocation towards the arm whose derivative for that direction is
    steepest.

    Parameters
    ----------
    outers : ndarray
        Stack of matrices, one per entry of ``warm_start``. They are summed
        with weights ``lambda / sigma`` to form the design matrix.
        Presumably ``outers[i] == np.outer(X[i], X[i])``.
        TODO confirm with caller.
    X : ndarray
        2-D array whose rows are the arms the allocation is defined over.
    Y : ndarray
        2-D array whose rows are the directions to cover.
    iterations : int
        The loop runs for ``k = 1 .. iterations - 1``, so this must be >= 2.
    threshold : float
        Stop once the relative change of the max predictive variance
        between consecutive iterations is below this value.
    warm_start : array_like
        Initial allocation. It is copied, so the caller's array is never
        modified.
    sigmas : array_like, optional
        Per-arm noise scales. Defaults to ones.

    Returns
    -------
    cov_A : ndarray
        Inverse (or pseudo-inverse, if singular) of the last evaluated
        design matrix.
    y_max_val : float
        Largest predictive variance over the rows of ``Y`` under that
        design.
    lambda_vec : ndarray
        Allocation after the final Frank-Wolfe step. That step is applied
        after ``cov_A``/``y_max_val`` are computed, so those two describe
        the allocation *before* it.

    Raises
    ------
    ValueError
        If ``iterations < 2``. In that case no iteration would run and
        there would be nothing to return.
    """
    if iterations < 2:
        raise ValueError("iterations must be at least 2")

    # Work on a float copy: updating warm_start in place would leak into the caller.
    lambda_vec = np.array(warm_start, dtype=float)
    if sigmas is None:
        sigmas = np.ones(len(lambda_vec))
    sigmas = np.asarray(sigmas, dtype=float)

    # X.T @ sqrt(inv(diag(sigmas))) is a loop-invariant per-column scaling.
    scaled_X_T = X.T * np.sqrt(1.0 / sigmas)

    old_y_max_val = 1  # initial reference for the relative-change criterion
    for k in range(1, iterations):
        # Design matrix: sum_i (lambda_i / sigma_i) * outers[i].
        A_lambda = np.tensordot(lambda_vec / sigmas, outers, axes=1)

        # Fall back to the pseudo-inverse when the design is exactly singular.
        if np.linalg.det(A_lambda) == 0:
            cov_A = np.linalg.pinv(A_lambda)
        else:
            cov_A = np.linalg.inv(A_lambda)

        # diag(Y @ cov_A @ Y.T) without materializing the full m x m product.
        diag_arg = np.einsum("ij,ij->i", Y @ cov_A, Y)
        max_idx = np.argmax(diag_arg)
        y_max = Y[max_idx]  # direction with max predictive uncertainty
        y_max_val = diag_arg[max_idx]

        lambda_derivative = -(y_max @ cov_A @ scaled_X_T) ** 2

        # Frank-Wolfe step towards the vertex with the steepest descent.
        alpha = 2 / (k + 2)
        lambda_vec -= alpha * lambda_vec
        lambda_vec[np.argmin(lambda_derivative)] += alpha

        if y_max_val == 0 or abs((old_y_max_val - y_max_val) / old_y_max_val) < threshold:
            break
        old_y_max_val = y_max_val
    return cov_A, y_max_val, lambda_vec
